# 剪裁图像和对应的标签
import os
import re
import sys
import cv2
import argparse
import numpy as np
import os.path as osp

from time import time
from multiprocessing import Pool
from shutil import get_terminal_size


class Timer:
    """A flexible Timer class."""
    def __init__(self, start=True, print_tmpl=None):
        self._is_running = False
        self.print_tmpl = print_tmpl if print_tmpl else '{:.3f}'
        if start:
            self.start()

    @property
    def is_running(self):
        """bool: indicate whether the timer is running"""
        return self._is_running

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, type, value, traceback):
        print(self.print_tmpl.format(self.since_last_check()))
        self._is_running = False

    def start(self):
        """Start the timer."""
        if not self._is_running:
            self._t_start = time()
            self._is_running = True
        self._t_last = time()

    def since_start(self):
        """Total time since the timer is started.

        Returns (float): Time in seconds.
        """
        if not self._is_running:
            raise ValueError('timer is not running')
        self._t_last = time()
        return self._t_last - self._t_start

    def since_last_check(self):
        """Time since the last checking.

        Either :func:`since_start` or :func:`since_last_check` is a checking
        operation.

        Returns (float): Time in seconds.
        """
        if not self._is_running:
            raise ValueError('timer is not running')
        dur = time() - self._t_last
        self._t_last = time()
        return dur


class ProgressBar:
    """A progress bar which can print the progress."""
    def __init__(self, task_num=0, bar_width=50, start=True, file=sys.stdout):
        self.task_num = task_num
        self.bar_width = bar_width
        self.completed = 0
        self.file = file
        if start:
            self.start()

    @property
    def terminal_width(self):
        width, _ = get_terminal_size()
        return width

    def start(self):
        if self.task_num > 0:
            self.file.write(f'[{" " * self.bar_width}] 0/{self.task_num}, '
                            'elapsed: 0s, ETA:')
        else:
            self.file.write('completed: 0, elapsed: 0s')
        self.file.flush()
        self.timer = Timer()

    def update(self, num_tasks=1):
        assert num_tasks > 0
        self.completed += num_tasks
        elapsed = self.timer.since_start()
        if elapsed > 0:
            fps = self.completed / elapsed
        else:
            fps = float('inf')
        if self.task_num > 0:
            percentage = self.completed / float(self.task_num)
            eta = int(elapsed * (1 - percentage) / percentage + 0.5)
            msg = f'\r[{{}}] {self.completed}/{self.task_num}, ' \
                  f'{fps:.1f} task/s, elapsed: {int(elapsed + 0.5)}s, ' \
                  f'ETA: {eta:5}s'

            bar_width = min(self.bar_width,
                            int(self.terminal_width - len(msg)) + 2,
                            int(self.terminal_width * 0.6))
            bar_width = max(2, bar_width)
            mark_width = int(bar_width * percentage)
            bar_chars = '>' * mark_width + ' ' * (bar_width - mark_width)
            self.file.write(msg.format(bar_chars))
        else:
            self.file.write(
                f'completed: {self.completed}, elapsed: {int(elapsed + 0.5)}s,'
                f' {fps:.1f} tasks/s')
        self.file.flush()

def main_cut_images(args):
    opt = {}
    opt['n_thread'] = args.n_thread
    tag = args.tag
    opt['img_folder'] = osp.join(args.data_root, tag,'src')
    opt['save_img_folder'] = osp.join(args.data_root, '{}_sub'.format(tag))
    opt['label_folder'] = osp.join(args.data_root, tag,'gt')
    opt['save_label_folder'] = osp.join(args.data_root, '{}_labels_sub'.format(tag))

    opt['crop_size'] = args.crop_size
    opt['step'] = args.step
    opt['thresh_size'] = args.thresh_size

    cut_images(opt)
    cut_labels(opt)
   
def cut_images(opt):
    input_folder = opt['img_folder']
    save_folder = opt['save_img_folder']
    opt['save_folder'] = save_folder
    if not osp.exists(save_folder):
        os.makedirs(save_folder)
        print(f'mkdir {save_folder} ...')
    else:
        print(f'Folder {save_folder} already exists. Exit.')
        sys.exit(1)

    img_list = list(os.listdir(input_folder))
    img_list = [osp.join(input_folder, v) for v in img_list]

    prog_bar = ProgressBar(len(img_list))
    pool = Pool(opt['n_thread'])
    for path in img_list:
        pool.apply_async(worker,
                         args=(path, opt),
                         callback=lambda arg: prog_bar.update())
    pool.close()
    pool.join()
    print('All processes done.')

def cut_labels(opt):
    input_folder = opt['label_folder']
    save_folder = opt['save_label_folder']
    opt['save_folder'] = save_folder
    if not osp.exists(save_folder):
        os.makedirs(save_folder)
        print(f'mkdir {save_folder} ...')
    else:
        print(f'Folder {save_folder} already exists. Exit.')
        sys.exit(1)

    img_list = list(os.listdir(input_folder))
    img_list = [osp.join(input_folder, v) for v in img_list]

    prog_bar = ProgressBar(len(img_list))
    pool = Pool(opt['n_thread'])
    for path in img_list:
        pool.apply_async(worker,
                         args=(path, opt),
                         callback=lambda arg: prog_bar.update())
    pool.close()
    pool.join()
    print('All processes done.')

def worker(path, opt):
    crop_size = opt['crop_size']
    step = opt['step']
    thresh_size = opt['thresh_size']
    img_name, extension = osp.splitext(osp.basename(path))
    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)

    if img.ndim == 2 or img.ndim == 3:
        h, w = img.shape[:2]
    else:
        raise ValueError(f'Image ndim should be 2 or 3, but got {img.ndim}')

    h_space = np.arange(0, h - crop_size + 1, step)
    if h - (h_space[-1] + crop_size) > thresh_size:
        h_space = np.append(h_space, h - crop_size)
    w_space = np.arange(0, w - crop_size + 1, step)
    if w - (w_space[-1] + crop_size) > thresh_size:
        w_space = np.append(w_space, w - crop_size)

    index = 0
    for x in h_space:
        for y in w_space:
            index += 1
            cropped_img = img[x:x + crop_size, y:y + crop_size, ...]
            cv2.imwrite(osp.join(opt['save_folder'],f'{img_name}_s{index:03d}{extension}'), cropped_img)
    process_info = f'Processing {img_name} ...'
    return process_info

def parse_args():
    parser = argparse.ArgumentParser(
        description='Prepare DIV2K dataset',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--data-root',  default=r"UDD6",help='dataset root')
    parser.add_argument('--tag',  default=r"train",help='dataset root')
    parser.add_argument('--crop-size',
                        nargs='?',
                        default=1024,
                        help='cropped size for HR images')
    parser.add_argument('--step',
                        nargs='?',
                        default=1024,
                        help='step size for HR images')
    parser.add_argument('--thresh-size',
                        nargs='?',
                        default=0,
                        help='threshold size for HR images')
    parser.add_argument('--n-thread',
                        nargs='?',
                        default=20,
                        help='thread number when using multiprocessing')
    args = parser.parse_args()
    return args

if __name__ == '__main__':
    args = parse_args()
    # extract subimages
    main_cut_images(args)
